
meshtag (Member) commented Dec 23, 2025

Added rank‑0 handling to masked load/store emulation by reinterpreting rank‑0 memrefs as 1‑D buffers with a synthetic index, preventing empty‑indices crashes.

Fixes #131243
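For context, below is the kind of input that previously crashed the emulation pass. This is a minimal, illustrative example (the function and operand names are made up here); the patch makes it lower by casting the rank-0 base to a 1-D memref and indexing lane 0.

```mlir
// Rank-0 base with an empty index list; previously this hit the
// empty-indices crash in VectorEmulateMaskedLoadStore.
func.func @rank0_maskedload(%base: memref<f32>, %mask: vector<1xi1>,
                            %pass: vector<1xf32>) -> vector<1xf32> {
  %0 = vector.maskedload %base[], %mask, %pass
      : memref<f32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
  return %0 : vector<1xf32>
}
```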

llvmbot commented Dec 23, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Prathamesh Tagore (meshtag)

Changes

Added rank‑0 handling to masked load/store emulation by reinterpreting rank‑0 memrefs as 1‑D buffers with a synthetic index, preventing empty‑indices crashes.

Fixes #131243
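For readers skimming the diff: the emulated IR for the rank-0 store case now looks roughly as follows. This is reconstructed from the CHECK lines of the new test further down; SSA names are illustrative and the exact ordering of constants may differ.

```mlir
// Emulated output sketch for the rank-0 masked store test below.
func.func @vector_maskedstore_rank0(%arg0: memref<12x32xf32>, %arg1: index,
                                    %arg2: index, %arg3: vector<1xi1>,
                                    %arg4: vector<1xf32>) {
  %c0 = arith.constant 0 : index
  // Rank-reducing subview yields a rank-0 memref.
  %subview = memref.subview %arg0[%arg1, %arg2] [1, 1] [1, 1]
      : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  // The pattern reinterprets the rank-0 base as a 1-D buffer of one element.
  %base, %offset = memref.extract_strided_metadata %subview
      : memref<f32, strided<[], offset: ?>> -> memref<f32>, index
  %cast = memref.reinterpret_cast %base to offset: [%offset], sizes: [1], strides: [1]
      : memref<f32> to memref<1xf32, strided<[1], offset: ?>>
  // Per-lane guard: store only if the mask bit for lane 0 is set.
  %m0 = vector.extract %arg3[0] : i1 from vector<1xi1>
  scf.if %m0 {
    %v0 = vector.extract %arg4[0] : f32 from vector<1xf32>
    memref.store %v0, %cast[%c0] : memref<1xf32, strided<[1], offset: ?>>
  }
  return
}
```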


Full diff: https://github.com/llvm/llvm-project/pull/173325.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp (+22)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+18)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index 7acc120508a44..cfd478b27908b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -19,6 +19,26 @@ using namespace mlir;
 
 namespace {
 
+/// Ensure that `base` has at least one index by reinterpreting rank-0 memrefs
+/// as 1-D memrefs. This prevents the pass from crashing on rank-0 memrefs.
+static void ensureBaseHasIndex(PatternRewriter &rewriter, Location loc,
+                               Value &base, SmallVectorImpl<Value> &indices,
+                               int64_t maskLength) {
+  if (!indices.empty())
+    return;
+
+  // Rank-0 memrefs have no indices; reinterpret as 1-D to step through lanes.
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, base);
+  SmallVector<OpFoldResult, 1> sizes = {rewriter.getIndexAttr(maskLength)};
+  SmallVector<OpFoldResult, 1> strides = {rewriter.getIndexAttr(1)};
+  base = memref::ReinterpretCastOp::create(rewriter, loc, meta.getBaseBuffer(),
+                                           meta.getOffset(), sizes, strides);
+
+  Type indexType = rewriter.getIndexType();
+  indices.push_back(arith::ConstantOp::create(rewriter, loc, indexType,
+                                              IntegerAttr::get(indexType, 0)));
+}
+
 /// Convert vector.maskedload
 ///
 /// Before:
@@ -65,6 +85,7 @@ struct VectorMaskedLoadOpConverter final
     Value base = maskedLoadOp.getBase();
     Value iValue = maskedLoadOp.getPassThru();
     auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
+    ensureBaseHasIndex(rewriter, loc, base, indices, maskLength);
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
     for (int64_t i = 0; i < maskLength; ++i) {
@@ -135,6 +156,7 @@ struct VectorMaskedStoreOpConverter final
     Value value = maskedStoreOp.getValueToStore();
     bool nontemporal = false;
     auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
+    ensureBaseHasIndex(rewriter, loc, base, indices, maskLength);
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
     for (int64_t i = 0; i < maskLength; ++i) {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 6e5d68c859e2c..8bace2ca9875b 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -123,3 +123,21 @@ func.func @vector_maskedstore_with_alignment(%arg0 : memref<4x5xf32>, %arg1 : ve
   vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 { alignment = 8 } : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
   return
 }
+
+// CHECK-LABEL:  @vector_maskedstore_rank0
+// CHECK-SAME:   (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xi1>, %[[ARG4:.*]]: vector<1xf32>) {
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT:   %[[SUBVIEW:.*]] = memref.subview %[[ARG0]]{{\[}}%[[ARG1]], %[[ARG2]]] [1, 1] [1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+// CHECK-NEXT:   %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[SUBVIEW]] : memref<f32, strided<[], offset: ?>> -> memref<f32>, index
+// CHECK-NEXT:   %[[REINT:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [1] : memref<f32> to memref<1xf32, strided<[1], offset: ?>>
+// CHECK-NEXT:   %[[M0:.*]] = vector.extract %[[ARG3]][0] : i1 from vector<1xi1>
+// CHECK-NEXT:   scf.if %[[M0]] {
+// CHECK-NEXT:     %[[V0:.*]] = vector.extract %[[ARG4]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     memref.store %[[V0]], %[[REINT]][%[[C0]]] : memref<1xf32, strided<[1], offset: ?>>
+func.func @vector_maskedstore_rank0(%arg0: memref<12x32xf32>, %arg1: index,
+                                   %arg2: index, %arg3: vector<1xi1>,
+                                   %arg4: vector<1xf32>) {
+  %subview = memref.subview %arg0[%arg1, %arg2] [1, 1] [1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  vector.maskedstore %subview[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<1xi1>, vector<1xf32>
+  return
+}
